### auxiliary functions for Simultaneous Multilateral Search
###
### last edit: 220107
###=====================

## begin of module
module Utils

export transBetween
export arrowed_spines, twin_arrowed_spines, get_slope

###============================================================================


## arrowed_spines for plots: https://gist.github.com/joferkington/3845684
"""
    arrowed_spines(ax; arrow_length=20, which="both", labels=("", ""), arrowprops=[])

Annotate the spines of `ax` with arrow heads (and optional axis labels).

# Arguments
- `ax`: axes-like object; only `ax.spines[name].get_transform()` and
  `ax.annotate(...)` are used (presumably a Matplotlib axes obtained through
  PyPlot -- confirm against the callers).

# Keywords
- `arrow_length`: offset of the annotation text from the arrow tip, passed in the
  "offset points" coordinate system.
- `which`: "both" annotates the "left" and then the "bottom" spine; "left" or
  "bottom" annotates only that spine; "down" is a placeholder that does nothing,
  and any other value is silently ignored.
- `labels`: `(xlabel, ylabel)` texts attached to the horizontal / vertical arrow.
- `arrowprops`: properties forwarded to `annotate`; when empty, a black "<|-"
  arrow is used.

# Returns
The value returned by `ax.annotate` when `which` is "left" or "bottom",
otherwise `nothing`.
"""
function arrowed_spines(ax; arrow_length=20, which="both", labels=("", ""), arrowprops=[])
    xlabel, ylabel = labels
    props = isempty(arrowprops) ? Dict(:arrowstyle=>"<|-", :facecolor=>"black") : arrowprops

    # Arrow at the end of a horizontal spine: 98% along x, text shifted along x.
    horizontal = function (spine)
        tr = ax.spines[spine].get_transform()
        return ax.annotate(xlabel, [.98, 0]; xycoords=("axes fraction", tr),
            xytext=[arrow_length, 0], textcoords=("offset points", tr),
            ha="left", va="center", arrowprops=props)
    end

    # Arrow at the end of a vertical spine: same set-up with x and y swapped.
    vertical = function (spine)
        tr = ax.spines[spine].get_transform()
        return ax.annotate(ylabel, [0, .98]; xycoords=(tr, "axes fraction"),
            xytext=[0, arrow_length], textcoords=(tr, "offset points"),
            ha="center", va="bottom", arrowprops=props)
    end

    if which == "both"
        vertical("left")
        horizontal("bottom")
        return nothing
    elseif which == "left" # only do the left vertical axis
        return vertical(which)
    elseif which == "bottom" # only do the bottom horizontal axis
        return horizontal(which)
    end
    # "down" (vertical axis pointing down) is not implemented; nothing to draw
    return nothing
end
## twin_arrowed_spines: arrowed vertical spines (left/right) for plots, adapted from https://gist.github.com/joferkington/3845684
"""
    twin_arrowed_spines(ax; which_spines=["left","right"], arrow_length=20,
                        labels=("", ""), color="black", arrowprops=[])

Put an arrow head (and the y label) on every spine listed in `which_spines`.

# Arguments
- `ax`: axes-like object; only `ax.spines[name].get_transform()` and
  `ax.annotate(...)` are used (presumably a Matplotlib axes obtained through
  PyPlot -- confirm against the callers).

# Keywords
- `which_spines`: spine names to annotate, processed in order.
- `arrow_length`: text offset along the spine, in "offset points".
- `labels`: `(xlabel, ylabel)`; only `ylabel` is used here.
- `color`: face and edge colour of the default arrow.
- `arrowprops`: properties forwarded to `annotate`; when empty, a "<|-" arrow
  in `color` is used.

# Returns
`nothing`.
"""
function twin_arrowed_spines(ax; which_spines=["left","right"], arrow_length=20, labels=("", ""), color="black", arrowprops=[])
    _, ylabel = labels
    props = isempty(arrowprops) ?
        Dict(:arrowstyle=>"<|-", :facecolor=>color, :edgecolor=>color) : arrowprops

    for spine in which_spines
        tr = ax.spines[spine].get_transform()
        # the "left" spine uses 0 as its horizontal coordinate, every other spine uses 1
        horiz = spine == "left" ? 0 : 1
        ax.annotate(ylabel, [horiz, .98]; xycoords=(tr, "axes fraction"),
            xytext=[horiz, arrow_length], textcoords=(tr, "offset points"),
            ha="center", va="bottom", arrowprops=props)
    end
    return nothing
end
## get the slope of a piecewise-linear curve (ax, ay) at a point x
"""
    get_slope(x, ay, ax)

Return the slope of the piecewise-linear curve through the points `(ax[i], ay[i])`
on the segment that starts at the last grid point not exceeding `x`, together with
that grid point.

# Arguments
- `x`: query abscissa.
- `ay`: ordinates of the curve.
- `ax`: abscissae of the curve; must be sorted in ascending order (the range check
  compares `x` against the first and last element only).

# Returns
A tuple `(slope, location)`:
- `location` is `[ax[j], ay[j]]`, where `j` is the last index with `ax[j] <= x`.
- `slope` is `(ay[j+1] - ay[j]) / (ax[j+1] - ax[j])`, or `0.0` when `j` is the last
  grid point.
- `(NaN, NaN)` is returned, after a warning, when `x` is `NaN`, when the grid is
  empty, or when `x` lies outside `[ax[1], ax[end]]`.
"""
function get_slope(x, ay, ax)
    ## guard against queries for which no segment can be selected
    if isempty(ax)
        @warn "Cannot determine the slope at x = $x on an empty grid"
        return NaN, NaN
    elseif isnan(x)
        @warn "Cannot determine the slope at x = NaN"
        return NaN, NaN
    elseif x < first(ax)
        @warn "Cannot determine the slope at x = $x beyond min(x)"
        return NaN, NaN
    elseif x > last(ax)
        @warn "Cannot determine the slope at x = $x beyond max(x)"
        return NaN, NaN
    end

    ## last grid point <= x; with repeated abscissae this picks the last duplicate,
    ## so the segment used below never has zero width
    j = searchsortedlast(ax, x)
    location = [ax[j], ay[j]]

    ## no segment to the right of the final grid point: report a flat slope
    if j == lastindex(ax) || j == lastindex(ay)
        slope = 0.0
    else
        slope = (ay[j+1] - ay[j]) / (ax[j+1] - ax[j])
    end
    return slope, location
end
###============================================================================


## transform between an unconstrained value and a value bounded by lower and/or upper
# Evaluate x + sqrt(x^2 + 1), which is always positive. For negative x the direct
# sum cancels (it becomes exactly 0 in floating point once x is large and negative),
# so the algebraically identical form 1/(sqrt(x^2 + 1) - x) is used instead.
function _transbetween_halfline(x)
    r = sqrt(x^2 + 1.0)
    return x < 0 ? 1.0/(r - x) : x + r
end

"""
    transBetween(x; reverse=false, lower=-Inf, upper=Inf, steepness=0.1)

Map an unconstrained scalar `x` into the interval described by `lower` and `upper`,
or, with `reverse=true`, map a bounded value back to the unconstrained scale.

# Keywords
- `reverse`: `false` maps unconstrained -> bounded, `true` applies the inverse map.
- `lower`, `upper`: bounds; an infinite value means "no bound on that side".
- `steepness`: scale factor of the transformation.

# Transformations (forward direction)
- no finite bound: `x` is returned unchanged.
- only `upper` finite: `upper - steepness*(x + sqrt(x^2 + 1))`.
- only `lower` finite: `lower + steepness*(x + sqrt(x^2 + 1))`.
- both finite: `lower + (upper - lower)/(1 + exp(-steepness*x))`.

With `reverse=true`, `x` is expected to lie strictly inside the bounds; a value on a
bound yields an infinite result (division by zero or `log(0)`).
"""
function transBetween(x; reverse=false, lower=-Inf, upper=Inf, steepness=0.1)
    nolower = isinf(lower)
    noupper = isinf(upper)
    if nolower && noupper # no transformation
        return x
    elseif nolower # upper bound only
        if reverse
            y = x - upper
            return 0.5*steepness/y - 0.5*y/steepness
        else
            return upper - steepness*_transbetween_halfline(x)
        end
    elseif noupper # lower bound only
        if reverse
            y = x - lower
            return -0.5*steepness/y + 0.5*y/steepness
        else
            return lower + steepness*_transbetween_halfline(x)
        end
    else # both bounds: logistic map and its inverse (logit)
        if reverse
            y = (x - lower)/(upper - x)
            return log(y)/steepness
        else
            xx = 1.0/(1.0 + exp(-steepness*x))
            return lower + xx*(upper - lower)
        end
    end
end
###============================================================================



## end of module
end
